{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-level Keras R (TF) CNN Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# SETUP\n",
    "#\n",
    "# Install keras R\n",
    "# install.packages('keras', repos = \"https://cloud.r-project.org\")\n",
    "# \n",
    "# Update reticulate from cran (it defaults to mran which has an outdated version)\n",
    "# install.packages(\"reticulate\", repos = \"https://cloud.r-project.org\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading required package: rjson\n"
     ]
    }
   ],
   "source": [
    "library(keras)\n",
    "use_python('/anaconda/envs/py35')\n",
    "\n",
    "# Import util functions\n",
    "source(\"./common/utils.R\")\n",
    "\n",
    "# Import hyper-parameters\n",
    "params <- load_params(\"cnn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Performance Improvement\n",
    "# 1. Make sure channels-first (not last)\n",
    "backend()$set_image_data_format('channels_first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# reticulate::py_config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS: Linux \n",
      "R version 3.4.1 (2017-06-30) \n",
      "Keras: 2.1.5 \n",
      "Tensorflow: 1.5 \n",
      "Keras using tensorflow \n",
      "Keras channel ordering is channels_first \n",
      "GPU:  Tesla P100-PCIE-16GB \n",
      "CUDA Version 8.0.61 \n",
      "CuDNN Version 6.0.21 \n"
     ]
    }
   ],
   "source": [
    "cat(\"OS:\", Sys.info()[\"sysname\"], \"\\n\")\n",
    "cat(R.version$version.string, \"\\n\")\n",
    "cat(\"Keras:\", paste0(packageVersion(\"keras\")), \"\\n\")\n",
    "cat(\"Tensorflow:\", paste0(packageVersion(\"tensorflow\")), \"\\n\")\n",
    "cat(\"Keras using\", backend()$backend(), \"\\n\")\n",
    "cat(\"Keras channel ordering is\", backend()$image_data_format(), \"\\n\") \n",
    "cat(\"GPU: \", get_gpu_name(), \"\\n\")\n",
    "cat(get_cuda_version(), \"\\n\")\n",
    "cat(get_cudnn_version(), \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "create_symbol <- function(n_classes = params$N_CLASSES){\n",
    "    \n",
    "    # initialize a sequential model\n",
    "    model <- keras_model_sequential() %>%\n",
    "    \n",
    "    # 2D convolutional layers being fed 32x32 pixel images   \n",
    "    layer_conv_2d(filters = 50, kernel_size = c(3, 3), padding = \"same\", activation = \"relu\", input_shape = c(3, 32, 32)) %>%\n",
    "    layer_conv_2d(filters = 50, kernel_size = c(3, 3), padding = \"same\", activation = \"relu\") %>%\n",
    "    \n",
    "    # max pooling\n",
    "    layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2)) %>%\n",
    "    layer_dropout(0.25) %>%\n",
    "\n",
    "    # 2D convolutional layers \n",
    "    layer_conv_2d(filters = 100, kernel_size = c(3, 3), padding = \"same\", activation = \"relu\") %>%\n",
    "    layer_conv_2d(filters = 100, kernel_size = c(3, 3), padding = \"same\", activation = \"relu\") %>%\n",
    "    \n",
    "    # max pooling\n",
    "    layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2)) %>%\n",
    "    layer_dropout(0.25) %>%\n",
    "\n",
    "    # flatten into feature vector\n",
    "    layer_flatten() %>%\n",
    "    layer_dense(512, activation = \"relu\") %>%\n",
    "    layer_dropout(0.5) %>%\n",
    "    layer_dense(n_classes, activation = \"softmax\")  \n",
    "    \n",
    "    return(model)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "init_model <- function(m, lr=params$LR, momentum=params$MOMENTUM){\n",
    "    m %>% compile(\n",
    "      loss = \"categorical_crossentropy\",\n",
    "      optimizer = optimizer_sgd(lr, momentum),\n",
    "      metrics = \"accuracy\"\n",
    "    )\n",
    "    return(m)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] \"Data does not exist. Downloading https://ikpublictutorial.blob.core.windows.net/deeplearningframeworks/cifar-10-binary.tar.gz \"\n",
      "[1] \"Extracting files ...\"\n"
     ]
    }
   ],
   "source": [
    "# Data Preparation \n",
    "cifar <- cifar_for_library(one_hot = TRUE, col_major = FALSE)\n",
    "x_train <- cifar$x_train\n",
    "y_train <- cifar$y_train\n",
    "x_test <- cifar$x_test\n",
    "y_test <- cifar$y_test\n",
    "\n",
    "rm(cifar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: 50000 3 32 32 \n",
      "x_test shape: 10000 3 32 32 \n",
      "y_train shape: 50000 10 \n",
      "y_test shape: 10000 10 \n"
     ]
    }
   ],
   "source": [
    "cat('x_train shape:', dim(x_train), '\\n')\n",
    "cat('x_test shape:', dim(x_test), '\\n')\n",
    "cat('y_train shape:', dim(y_train), '\\n')\n",
    "cat('y_test shape:', dim(y_test), '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load symbol\n",
    "sym = create_symbol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initialise model\n",
    "model = init_model(sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "________________________________________________________________________________\n",
      "Layer (type)                        Output Shape                    Param #     \n",
      "================================================================================\n",
      "conv2d_1 (Conv2D)                   (None, 50, 32, 32)              1400        \n",
      "________________________________________________________________________________\n",
      "conv2d_2 (Conv2D)                   (None, 50, 32, 32)              22550       \n",
      "________________________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2D)      (None, 50, 16, 16)              0           \n",
      "________________________________________________________________________________\n",
      "dropout_1 (Dropout)                 (None, 50, 16, 16)              0           \n",
      "________________________________________________________________________________\n",
      "conv2d_3 (Conv2D)                   (None, 100, 16, 16)             45100       \n",
      "________________________________________________________________________________\n",
      "conv2d_4 (Conv2D)                   (None, 100, 16, 16)             90100       \n",
      "________________________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2D)      (None, 100, 8, 8)               0           \n",
      "________________________________________________________________________________\n",
      "dropout_2 (Dropout)                 (None, 100, 8, 8)               0           \n",
      "________________________________________________________________________________\n",
      "flatten_1 (Flatten)                 (None, 6400)                    0           \n",
      "________________________________________________________________________________\n",
      "dense_1 (Dense)                     (None, 512)                     3277312     \n",
      "________________________________________________________________________________\n",
      "dropout_3 (Dropout)                 (None, 512)                     0           \n",
      "________________________________________________________________________________\n",
      "dense_2 (Dense)                     (None, 10)                      5130        \n",
      "================================================================================\n",
      "Total params: 3,441,592\n",
      "Trainable params: 3,441,592\n",
      "Non-trainable params: 0\n",
      "________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "   user  system elapsed \n",
       " 69.565  14.586  71.634 "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Main training loop\n",
    "system.time(\n",
    "    model %>% fit(\n",
    "        x_train, y_train,\n",
    "        batch_size = params$BATCHSIZE,\n",
    "        epochs = params$EPOCHS)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Main evaluation loop\n",
    "y_guess <- model %>% predict_classes(x_test, batch_size = params$BATCHSIZE)\n",
    "y_truth <- apply(y_test, 1, function(x) which.max(x)-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] \"Accuracy: 0.7685\"\n"
     ]
    }
   ],
   "source": [
    "print(paste0(\"Accuracy: \", sum(y_guess == y_truth)/length(y_guess)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.4.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
